package log

import (
	"fmt"
	"os"

	"github.com/go-kit/log"
	"github.com/go-kit/log/level"
	"github.com/prometheus/client_golang/prometheus"
	"github.com/prometheus/common/promslog"
	"github.com/weaveworks/common/logging"
	"github.com/weaveworks/common/server"
)

var (
	// Logger is a shared go-kit logger. It starts out as a no-op logger and is
	// replaced by InitLogger.
	// TODO: Change all components to take a non-global logger via their constructors.
	// Prefer accepting a non-global logger as an argument.
	Logger  = log.NewNopLogger()
	// SLogger is the shared slog counterpart of Logger. It starts out as a
	// no-op logger and is replaced by InitLogger with one derived from Logger.
	SLogger = promslog.NewNopLogger()

	// logMessages counts the records passed to PrometheusLogger.Log,
	// partitioned by the "level" label.
	logMessages = prometheus.NewCounterVec(prometheus.CounterOpts{
		Name: "log_messages_total",
		Help: "Total number of log messages.",
	}, []string{"level"})

	// supportedLevels lists the go-kit levels whose counter series are created
	// up front by newPrometheusLoggerFrom.
	supportedLevels = []level.Value{
		level.DebugValue(),
		level.InfoValue(),
		level.WarnValue(),
		level.ErrorValue(),
	}
)

// init registers the logMessages counter with the default Prometheus
// registerer. MustRegister panics if the registration fails.
func init() {
	prometheus.MustRegister(logMessages)
}

// InitLogger sets up logging from the server configuration: it replaces the
// global gokit logger (util_log.Logger) and its slog counterpart (SLogger),
// and overrides the default logger for the server by assigning cfg.Log.
func InitLogger(cfg *server.Config) {
	base := newLoggerWithFormat(cfg.LogFormat)

	// Records logged through util_log.Logger need 6 stack frames skipped to
	// report the real call site.
	global := newPrometheusLoggerFrom(base, cfg.LogLevel, "caller", log.Caller(6))
	Logger = global
	SLogger = GoKitLogToSlog(global)

	// cfg.Log wraps the log function, so 7 stack frames are skipped to get the
	// caller information. This works from Go 1.12 onwards; earlier versions
	// always show the compiler-generated wrapper function, marked
	// <autogenerated>.
	serverLogger := newPrometheusLoggerFrom(base, cfg.LogLevel, "caller", log.Caller(7))
	cfg.Log = logging.GoKit(serverLogger)
}

// PrometheusLogger exposes Prometheus counters for each of go-kit's log levels.
// It wraps another go-kit logger; its Log method forwards every record to that
// logger and counts it in the log_messages_total metric.
type PrometheusLogger struct {
	// logger is the wrapped logger that receives every record passed to Log.
	logger   log.Logger
	// logLevel is the level this logger was constructed with. It is stored by
	// newPrometheusLoggerFrom; no method visible in this file reads it.
	logLevel logging.Level
}

// NewPrometheusLogger returns a PrometheusLogger that writes in the given
// format, filters records by level l and exposes Prometheus counters for the
// various log levels. The returned error is always nil.
func NewPrometheusLogger(l logging.Level, format logging.Format) (log.Logger, error) {
	return newPrometheusLoggerFrom(newLoggerWithFormat(format), l), nil
}

// newLoggerWithFormat returns a go-kit logger writing to os.Stderr through a
// sync writer. The output is JSON when format is "json" and logfmt for any
// other format.
func newLoggerWithFormat(format logging.Format) log.Logger {
	out := log.NewSyncWriter(os.Stderr)
	switch format.String() {
	case "json":
		return log.NewJSONLogger(out)
	default:
		return log.NewLogfmtLogger(out)
	}
}

// newPrometheusLoggerFrom wraps logger in a PrometheusLogger. The wrapped
// chain adds a UTC "ts" timestamp and the extra keyvals to every record and is
// filtered by logLevel. A counter series is also created for each supported
// level.
func newPrometheusLoggerFrom(logger log.Logger, logLevel logging.Level, keyvals ...any) log.Logger {
	// Initialise counters for all supported levels.
	for i := range supportedLevels {
		logMessages.WithLabelValues(supportedLevels[i].String())
	}

	// The chain is ordered with the level filter outermost, to avoid expensive
	// log.Valuer evaluation for records of a disallowed level.
	// Ref: https://github.com/go-kit/log/issues/14#issuecomment-945038252
	timestamped := log.With(logger, "ts", log.DefaultTimestampUTC)
	contextual := log.With(timestamped, keyvals...)
	filtered := level.NewFilter(contextual, logLevel.Gokit)

	return &PrometheusLogger{
		logger:   filtered,
		logLevel: logLevel,
	}
}

// Log forwards kv to the wrapped logger and then increments the
// log_messages_total counter for the record's level.
//
// The level is the first level.Value found in a value position (odd index) of
// kv; records without one are counted under the "unknown" label. The counter
// is incremented even when the wrapped logger reports an error.
//
// It returns the error reported by the wrapped logger, so that write or
// encoding failures are not hidden from callers.
func (pl *PrometheusLogger) Log(kv ...any) error {
	err := pl.logger.Log(kv...)

	lvl := "unknown"
	for i := 1; i < len(kv); i += 2 {
		if v, ok := kv[i].(level.Value); ok {
			lvl = v.String()
			break
		}
	}
	logMessages.WithLabelValues(lvl).Inc()
	return err
}

// CheckFatal is a no-op when err is nil. Otherwise it logs err at error level
// through the global Logger and exits the process with code 1. When location
// is non-empty, the record also carries msg="error <location>".
func CheckFatal(location string, err error) {
	if err == nil {
		return
	}

	errLogger := level.Error(Logger)
	if location != "" {
		errLogger = log.With(errLogger, "msg", "error "+location)
	}
	// %+v gets the stack trace from errors using github.com/pkg/errors
	errLogger.Log("err", fmt.Sprintf("%+v", err))
	os.Exit(1)
}
